package com.sfzd5.StudyJavaCV.mnist;

import org.bytedeco.javacpp.indexer.FloatIndexer;
import org.bytedeco.javacpp.indexer.UByteIndexer;
import org.bytedeco.javacpp.opencv_core.*;
import org.opencv.core.CvType;

import java.awt.image.BufferedImage;
import java.io.BufferedInputStream;
import java.io.DataInputStream;
import java.io.FileInputStream;
import java.io.IOException;

import static org.bytedeco.javacpp.opencv_core.CV_8UC1;
import static org.bytedeco.javacpp.opencv_highgui.imshow;
import static org.bytedeco.javacpp.opencv_highgui.waitKey;


/**
 * Reads the MNIST handwritten-digit dataset from its binary files into OpenCV {@link Mat}s
 * and Java arrays. Image files must start with the magic number 0x00000803, label files
 * with 0x00000801; all header fields are 4-byte big-endian integers.
 */
public class MnistRead {

    public static final String TRAIN_IMAGES_FILE = "MNIST_data/train-images.idx3-ubyte";
    public static final String TRAIN_LABELS_FILE = "MNIST_data/train-labels.idx1-ubyte";
    public static final String TEST_IMAGES_FILE = "MNIST_data/t10k-images.idx3-ubyte";
    public static final String TEST_LABELS_FILE = "MNIST_data/t10k-labels.idx1-ubyte";

    /** Expected magic number at the start of an image (idx3-ubyte) file. */
    private static final int IMAGE_FILE_MAGIC = 0x00000803;

    /** Expected magic number at the start of a label (idx1-ubyte) file. */
    private static final int LABEL_FILE_MAGIC = 0x00000801;

    /** Width of a one-hot label row produced by {@link #getTrainLabels(String)}. */
    private static final int NUM_CLASSES = 10;


    /**
     * Converts bytes into a lowercase hex string, two digits per byte.
     *
     * @param bytes the bytes to encode
     * @return the hex string, e.g. {@code "00000803"} for the bytes {0, 0, 8, 3}
     */
    public static String bytesToHex(byte[] bytes) {
        StringBuilder sb = new StringBuilder(bytes.length * 2);
        for (byte b : bytes) {
            String hex = Integer.toHexString(b & 0xFF);
            if (hex.length() < 2) {
                sb.append('0');
            }
            sb.append(hex);
        }
        return sb.toString();
    }

    /** Opens {@code fileName} as a buffered stream that can read big-endian header ints. */
    private static DataInputStream open(String fileName) throws IOException {
        return new DataInputStream(new BufferedInputStream(new FileInputStream(fileName)));
    }

    /**
     * Reads the 4-byte magic number and rejects the file if it does not match.
     *
     * @throws RuntimeException if the magic number differs from {@code expected}
     * @throws IOException if the stream ends before four bytes are read (EOFException)
     */
    private static void checkMagic(DataInputStream in, int expected) throws IOException {
        if (in.readInt() != expected) {
            throw new RuntimeException("Please select the correct file!");
        }
    }

    /**
     * Reads a 4-byte big-endian header field that describes a size or count.
     *
     * @param what description of the field, used in the error message
     * @throws RuntimeException if the value is negative (would break array allocation)
     * @throws IOException if the stream ends before four bytes are read (EOFException)
     */
    private static int readSize(DataInputStream in, String what) throws IOException {
        int value = in.readInt();
        if (value < 0) {
            throw new RuntimeException("Invalid " + what + " in MNIST header: " + value);
        }
        return value;
    }


    /**
     * Builds the training-data matrix from an image file: one row per image and one
     * CV_32FC1 column per pixel, holding the raw pixel values 0..255 (not normalized).
     *
     * @param fileName the 'train' or 'test' image file
     * @return a {@code number x (rows * cols)} matrix, one image per row
     * @throws RuntimeException if the file has the wrong magic number, a negative size field,
     *     is truncated, or cannot be read
     */
    public static Mat getTrainData(String fileName) {
        try (DataInputStream in = open(fileName)) {
            checkMagic(in, IMAGE_FILE_MAGIC);
            int number = readSize(in, "image count");
            int rows = readSize(in, "row count");
            int cols = readSize(in, "column count");

            int pixels = rows * cols;
            Mat x = new Mat(number, pixels, CvType.CV_32FC1);
            FloatIndexer indexer = x.createIndexer();
            byte[] image = new byte[pixels];
            for (int i = 0; i < number; i++) {
                // readFully throws EOFException on a truncated file instead of yielding -1.
                in.readFully(image);
                for (int j = 0; j < pixels; j++) {
                    indexer.put(i, j, image[j] & 0xFF);
                }
            }
            return x;
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Loads every image of an image file as its own single-channel 8-bit matrix.
     *
     * @param fileName the 'train' or 'test' image file
     * @return one {@code rows x cols} CV_8UC1 {@link Mat} per image, in file order
     * @throws RuntimeException if the file has the wrong magic number, a negative size field,
     *     is truncated, or cannot be read
     */
    public static Mat[] getImages(String fileName) {
        try (DataInputStream in = open(fileName)) {
            checkMagic(in, IMAGE_FILE_MAGIC);
            int number = readSize(in, "image count");
            int rows = readSize(in, "row count");
            int cols = readSize(in, "column count");

            Mat[] images = new Mat[number];
            byte[] pixels = new byte[rows * cols];
            for (int i = 0; i < number; i++) {
                in.readFully(pixels);
                Mat m = new Mat(rows, cols, CV_8UC1);
                UByteIndexer indexer = m.createIndexer();
                // Index by (row, col) so the write does not depend on how the indexer
                // interprets a single linear index on a 2-D matrix.
                for (int r = 0; r < rows; r++) {
                    for (int c = 0; c < cols; c++) {
                        indexer.put(r, c, pixels[r * cols + c] & 0xFF);
                    }
                }
                images[i] = m;
            }
            return images;
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Loads the labels as one-hot rows for training: each label becomes a float[10] row of
     * the returned matrix with 1 at the label's index and 0 elsewhere.
     *
     * @param fileName the 'train' or 'test' label file
     * @return a {@code labelCount x 10} CV_32FC1 matrix; a label outside 0..9 yields an
     *     all-zero row
     * @throws RuntimeException if the label file cannot be read (see {@link #getLabels(String)})
     */
    public static Mat getTrainLabels(String fileName) {
        byte[] labels = getLabels(fileName);
        Mat x = new Mat(labels.length, NUM_CLASSES, CvType.CV_32FC1);
        FloatIndexer indexer = x.createIndexer();
        for (int i = 0; i < labels.length; i++) {
            for (int j = 0; j < NUM_CLASSES; j++) {
                indexer.put(i, j, j == labels[i] ? 1f : 0f);
            }
        }
        return x;
    }

    /**
     * Reads the raw value of every label in a label file.
     *
     * @param fileName the 'train' or 'test' label file
     * @return the labels in file order
     * @throws RuntimeException if the file has the wrong magic number, a negative count,
     *     is truncated, or cannot be read
     */
    public static byte[] getLabels(String fileName) {
        try (DataInputStream in = open(fileName)) {
            checkMagic(in, LABEL_FILE_MAGIC);
            int number = readSize(in, "label count");
            byte[] labels = new byte[number];
            // readFully throws EOFException on a truncated file instead of storing -1 labels.
            in.readFully(labels);
            return labels;
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }


    /**
     * Renders row-major 8-bit gray values into an RGB image with inverted intensity
     * (0 becomes white, 255 becomes black).
     *
     * @param pixelValues gray values, row-major, at least {@code width * high} entries
     * @param width image width in pixels (length of one row)
     * @param high image height in pixels (number of rows)
     * @param fileName unused; kept for backward compatibility
     * @return the rendered image
     * @throws IOException never thrown; declared so existing callers that catch it still compile
     */
    public static BufferedImage drawGrayPicture(int[] pixelValues, int width, int high, String fileName) throws IOException {
        BufferedImage bufferedImage = new BufferedImage(width, high, BufferedImage.TYPE_INT_RGB);
        for (int y = 0; y < high; y++) {
            for (int x = 0; x < width; x++) {
                int pixel = 255 - pixelValues[y * width + x];
                int value = pixel + (pixel << 8) + (pixel << 16);   // r = g = b yields a gray level
                bufferedImage.setRGB(x, y, value);
            }
        }
        return bufferedImage;
    }


    /**
     * Demo: loads the training images and labels, prints the first label and shows the
     * first image in a window titled with that label until a key is pressed.
     */
    public static void main(String[] args) {
        Mat[] images = getImages(TRAIN_IMAGES_FILE);
        byte[] labels = getLabels(TRAIN_LABELS_FILE);

        String firstLabel = String.valueOf(labels[0]);
        System.out.println(firstLabel);
        imshow(firstLabel, images[0]);
        waitKey();

        System.out.println();
    }
}